package com.borealis.common.utils.lock;

import java.util.HashMap;
import java.util.Map;
import java.util.concurrent.atomic.AtomicInteger;
import java.util.concurrent.atomic.AtomicLong;

/**
 * Abstract base class for locks keyed by string ids.
 *
 * <p>Subclasses provide the actual lock primitive through {@link #lock(String, Long)}, {@link
 * #returnLock(String, Long)} and {@link #unlock(String, Long)}, plus an owner id for the calling
 * thread through {@link #getThreadId()}. This class records, per calling thread, which ids that
 * thread currently holds through this manager, so that locking an id the thread already holds does
 * not reach the subclass again.
 *
 * <p>Note: at the manager level locks are not counted. Calling {@link #lock(String...)} twice on
 * the same id and then {@link #unlock(String...)} once releases it.
 *
 * @author yaoweixin
 * @date 2018/10/15
 */
public abstract class AbstractIdLock {

  /** A single lockable resource. */
  public interface ILockable {

    /** Acquires the lock. */
    void lock();

    /**
     * Tries to acquire the lock. When competing for the lock, a caller that does not win gets
     * {@code false} back directly.
     *
     * @return {@code true} if the lock is held after the call, {@code false} otherwise
     */
    boolean returnLock();

    /** Releases the lock. */
    void unlock();

    /**
     * Returns the current hold count of this lock.
     *
     * @return hold count; 0 when not held
     */
    int getLockCount();
  }

  /**
   * Owner id of the calling thread, obtained from {@link #getThreadId()} on first use in each thread
   * and then cached by the {@link ThreadLocal}.
   */
  final ThreadLocal<Long> threadId = ThreadLocal.withInitial(AbstractIdLock.this::getThreadId);

  /**
   * Returns the owner id to use for the calling thread.
   *
   * <p>Must not return {@code null}: the value is unboxed to {@code long} when a lock is taken or
   * released.
   *
   * @return owner id of the calling thread
   */
  protected abstract Long getThreadId();

  /**
   * Acquires the lock on a resource. Presumably blocks until acquired (TODO confirm per subclass).
   *
   * @param res      resource id
   * @param threadId owner id of the calling thread
   */
  protected abstract void lock(String res, Long threadId);

  /**
   * Tries to acquire the lock on a resource without winning a competition for it.
   *
   * @param res      resource id
   * @param threadId owner id of the calling thread
   * @return {@code true} if the lock was acquired, {@code false} otherwise
   */
  protected abstract boolean returnLock(String res, Long threadId);

  /**
   * Releases the lock on a resource.
   *
   * @param res      resource id
   * @param threadId owner id of the calling thread
   */
  protected abstract void unlock(String res, Long threadId);

  /**
   * Reentrant wrapper around one resource id. It remembers the owner id that holds it and how many
   * times it was acquired; only the first acquisition and the final release reach the manager.
   */
  static class LockItem implements ILockable {

    final AbstractIdLock manager;
    final String res;
    /** Outstanding acquisitions by {@link #lockThreadId}; 0 means the item is not held. */
    final AtomicInteger lockCount = new AtomicInteger(0);
    /** Owner id of the current holder; only meaningful while {@link #lockCount} is positive. */
    final AtomicLong lockThreadId = new AtomicLong(0);

    LockItem(AbstractIdLock manager, String res) {
      this.manager = manager;
      this.res = res;
    }

    @Override
    public void lock() {
      long threadId = manager.threadId.get();
      // Re-entry by the current holder only bumps the count.
      if (lockCount.get() > 0 && lockThreadId.get() == threadId) {
        lockCount.incrementAndGet();
        return;
      }
      manager.lock(res, threadId);
      lockThreadId.set(threadId);
      lockCount.set(1);
    }

    @Override
    public boolean returnLock() {
      long threadId = manager.threadId.get();
      // Re-entry by the current holder always succeeds.
      if (lockCount.get() > 0 && lockThreadId.get() == threadId) {
        lockCount.incrementAndGet();
        return true;
      }
      boolean acquired = manager.returnLock(res, threadId);
      if (acquired) {
        lockThreadId.set(threadId);
        lockCount.set(1);
      }
      return acquired;
    }

    @Override
    public void unlock() {
      long threadId = manager.threadId.get();
      // Ignore a release by a non-holder or on an item that is not held. The count check stops an
      // owner id equal to the initial value 0 from "releasing" a lock it never acquired.
      if (threadId != lockThreadId.get() || lockCount.get() <= 0) {
        return;
      }
      if (lockCount.get() > 1) {
        lockCount.decrementAndGet();
        return;
      }
      lockCount.set(0);
      lockThreadId.compareAndSet(threadId, 0);
      manager.unlock(res, threadId);
    }

    @Override
    public int getLockCount() {
      return lockCount.get();
    }
  }

  /** Per-thread map from id to the lock item that thread currently holds through this manager. */
  final ThreadLocal<Map<String, ILockable>> localLocks = ThreadLocal.withInitial(HashMap::new);

  /**
   * Acquires the locks for the given ids, in order. Ids the calling thread already holds through
   * this manager are skipped.
   *
   * @param ids resource ids to lock
   */
  public void lock(String... ids) {
    Map<String, ILockable> localLocks = this.localLocks.get();
    for (String id : ids) {
      if (!localLocks.containsKey(id)) {
        ILockable lock = new LockItem(this, id);
        lock.lock();
        // Record the item only after lock() returned, i.e. once the lock is really held.
        localLocks.put(id, lock);
      }
    }
  }

  /**
   * Tries to acquire the lock for one id.
   *
   * <p>A failed attempt is not remembered, so a later {@link #lock(String...)} or {@code
   * returnLock} on the same id will really try to acquire it again.
   *
   * @param id resource id
   * @return {@code true} if this call acquired the lock; {@code false} if it was not acquired, and
   *     also {@code false} if the calling thread already holds it through this manager (no new
   *     acquisition happens in that case)
   */
  public boolean returnLock(String id) {
    Map<String, ILockable> localLocks = this.localLocks.get();
    if (localLocks.containsKey(id)) {
      return false;
    }
    ILockable lock = new LockItem(this, id);
    boolean acquired = lock.returnLock();
    if (acquired) {
      localLocks.put(id, lock);
    }
    return acquired;
  }

  /**
   * Releases the locks for the given ids held by the calling thread through this manager. Ids not
   * held are ignored. All releases run before any id is forgotten.
   *
   * @param ids resource ids to unlock
   */
  public void unlock(String... ids) {
    Map<String, ILockable> localLocks = this.localLocks.get();
    for (String id : ids) {
      ILockable lock = localLocks.get(id);
      if (lock != null) {
        lock.unlock();
      }
    }
    for (String id : ids) {
      localLocks.remove(id);
    }
  }

  /** Releases every lock the calling thread holds through this manager. */
  public void unlockAll() {
    Map<String, ILockable> localLocks = this.localLocks.get();
    for (ILockable lock : localLocks.values()) {
      lock.unlock();
    }
    localLocks.clear();
  }
}
